from MODEL import Discriminator,Generater
from DataLoader import DataLoader
from parser1 import args
import numpy as np
import torch
import random
from time import time
import torch.optim as optim
from function import initialize_weights,get_rat,get_performance
from tqdm import tqdm
import torch.nn as nn
import multiprocessing
import heapq
import os
from data.logging import Logger
from datetime import datetime
from LightGCN import Light_GCN
import torch.nn.functional as F
import matplotlib.pyplot as plt

# Module-level loader: shared by the multiprocessing workers spawned in the
# test functions below (presumably inherited on fork — verify on non-Linux).
dataloader = DataLoader(path=args.data_path + args.dataset, batch_size=args.batch_size)
# Top-K cutoffs for the ranking metrics, parsed from a string like "[10, 20, 50]".
# NOTE(review): eval() on a CLI argument — fine for trusted args, unsafe otherwise.
Ks = eval(args.Ks)

def ranklist_by_heapq(user_pos_test, test_items, rating, Ks, is_gcn=False):
    """Rank candidate items by score and mark hits against the test positives.

    Args:
        user_pos_test: iterable of the user's ground-truth positive item ids.
        test_items: candidate item ids to rank.
        rating: scores — indexed positionally (aligned with test_items) when
            is_gcn is True, otherwise indexable by item id.
        Ks: list of cutoff values; only max(Ks) items are ranked.
        is_gcn: selects the rating indexing convention (see above).

    Returns:
        (r, auc) where r is a 0/1 hit list of length max(Ks) ordered by
        descending score, and auc is always 0.0 (not computed here).
    """
    if is_gcn:
        # scores are positionally aligned with the candidate list
        item_score = {item: rating[i] for i, item in enumerate(test_items)}
    else:
        # scores are indexable directly by item id
        item_score = {item: rating[item] for item in test_items}

    K_max_item_score = heapq.nlargest(max(Ks), item_score, key=item_score.get)

    # set membership: the original did O(len(user_pos_test)) list scans per item
    pos_set = set(user_pos_test)
    r = [1 if item in pos_set else 0 for item in K_max_item_score]
    auc = 0.
    return r, auc

def test_one_user(x):
    """Compute ranking metrics for one user over all non-training items.

    x is a tuple (rating, uid, is_val): full-catalog scores for the user,
    the user id, and whether to evaluate against the validation split.
    """
    rating, u, is_val = x[0], x[1], x[-1]

    # items the user interacted with during training are excluded from ranking
    try:
        training_items = dataloader.train_items[u]
    except Exception:
        training_items = []

    # ground-truth positives come from the chosen evaluation split
    user_pos_test = dataloader.val_set[u] if is_val else dataloader.test_set[u]

    candidate_items = list(set(range(dataloader.n_items)) - set(training_items))
    r, auc = ranklist_by_heapq(user_pos_test, candidate_items, rating, Ks)

    return get_performance(user_pos_test, r, auc, Ks)


def test_one_recall_user(x):
    """Compute ranking metrics for one user over a pre-recalled candidate set.

    x is a tuple (rating, uid, is_val, test_items): scores positionally
    aligned with test_items, the user id, the split flag, and the recalled
    candidate item ids.
    """
    rating, u, is_val, candidate_items = x[0], x[1], x[2], x[-1]

    # ground-truth positives come from the chosen evaluation split
    user_pos_test = dataloader.val_set[u] if is_val else dataloader.test_set[u]

    # is_gcn=True: rating is indexed by position within candidate_items
    r, auc = ranklist_by_heapq(user_pos_test, candidate_items, rating, Ks, True)

    return get_performance(user_pos_test, r, auc, Ks)

def GCN_test(test_user,is_val):
    """Evaluate the LightGCN model on `test_user` in parallel batches.

    Uses module-level globals: GCN, batch_size, cores, Ks.

    Args:
        test_user: list of user ids to evaluate.
        is_val: True to score against the validation split, False for test.

    Returns:
        dict of metric arrays (precision/recall/ndcg/hit_ratio per K) plus a
        scalar 'auc', each averaged over all users.
    """
    GCN.eval()
    with torch.no_grad():
        user_emb, item_emb = GCN()
    result = {'precision': np.zeros(len(Ks)), 'recall': np.zeros(len(Ks)), 'ndcg': np.zeros(len(Ks)),
              'hit_ratio': np.zeros(len(Ks)), 'auc': 0.}
    count = 0
    n_user = len(test_user)

    # ceil division: the original `n_user // batch_size + 1` produced a
    # trailing empty batch whenever batch_size divided n_user exactly
    n_batch = (n_user + batch_size - 1) // batch_size
    pool = multiprocessing.Pool(cores)
    try:
        for u_batch in range(n_batch):
            start = u_batch * batch_size
            end = min(start + batch_size, n_user)
            user_batch = test_user[start:end]
            # score every item for this batch of users
            rat = torch.matmul(user_emb[user_batch], torch.transpose(item_emb, 0, 1))
            rat = rat.detach().cpu().numpy()
            rat_uid = zip(rat, user_batch, [is_val] * len(user_batch))
            batch_result = pool.map(test_one_user, rat_uid)
            count += len(batch_result)
            for re in batch_result:
                result['precision'] += re['precision'] / n_user
                result['recall'] += re['recall'] / n_user
                result['ndcg'] += re['ndcg'] / n_user
                result['hit_ratio'] += re['hit_ratio'] / n_user
                result['auc'] += re['auc'] / n_user
    finally:
        # fix: the original leaked the pool on any exception and never joined it
        pool.close()
        pool.join()
    assert count == n_user
    return result


def test(test_user,test_items,max_inter,is_val):
    """Evaluate the generator/discriminator pair over recalled candidates.

    Uses module-level globals: generater, discriminator, recall_item,
    batch_size, cores, Ks.

    Args:
        test_user: list of user ids to evaluate.
        test_items: per-user item histories aligned with test_user.
        max_inter: padding length forwarded to the generator.
        is_val: True for the validation split, False for the test split.

    Returns:
        dict of metric arrays (precision/recall/ndcg/hit_ratio per K) plus a
        scalar 'auc', each averaged over all users.
    """
    generater.eval()
    discriminator.eval()

    result = {'precision': np.zeros(len(Ks)), 'recall': np.zeros(len(Ks)), 'ndcg': np.zeros(len(Ks)),
              'hit_ratio': np.zeros(len(Ks)), 'auc': 0.}
    count = 0
    n_user = len(test_user)

    # ceil division: the original `n_user // batch_size + 1` produced a
    # trailing empty batch whenever batch_size divided n_user exactly
    n_batch = (n_user + batch_size - 1) // batch_size
    pool = multiprocessing.Pool(cores)
    try:
        for u_batch in range(n_batch):
            start = u_batch * batch_size
            end = min(start + batch_size, n_user)
            user_batch = test_user[start:end]
            test_batch = test_items[start:end]
            with torch.no_grad():
                neg_user_feature, pos_user_featuer, item_feature = generater(user_batch, test_batch, max_inter, None ,True)
            # restrict scoring to each user's recalled candidate items
            neg_item_feature = item_feature[recall_item[user_batch]]
            rat,_ = get_rat(neg_user_feature, neg_item_feature, discriminator)
            rat = rat.detach().cpu().numpy()
            rat_uid = zip(rat, user_batch, [is_val] * len(user_batch), recall_item[user_batch])
            batch_result = pool.map(test_one_recall_user, rat_uid)
            count += len(batch_result)
            for re in batch_result:
                result['precision'] += re['precision'] / n_user
                result['recall'] += re['recall'] / n_user
                result['ndcg'] += re['ndcg'] / n_user
                result['hit_ratio'] += re['hit_ratio'] / n_user
                result['auc'] += re['auc'] / n_user
    finally:
        # fix: the original leaked the pool on any exception and never joined it
        pool.close()
        pool.join()
    assert count == n_user
    return result


def train():
    """Adversarial training loop.

    Every third epoch (epoch % 3 == 0) updates the discriminator; the other
    epochs update the generator. After each epoch the model is validated, and
    whenever validation recall@Ks[1] improves, the test split is re-evaluated
    and logged. Uses module-level globals: dataloader, generater,
    discriminator, optim_D, optim_G, Dis_loss, recall_item, logger,
    result_logger, args.
    """
    loss_loger, pre_loger, rec_loger, ndcg_loger, hit_loger = [], [], [], [], []
    best_recall = 0
    # fix: if recall never improved, the final log line raised NameError on test_ret
    test_ret = None
    n_batch = dataloader.n_train//args.batch_size + 1
    stopping_step = 0
    for epoch in range(args.epoch):
        t1 = time()
        g_loss,d_f_loss,d_r_loss = 0., 0.,0.
        if epoch%3==0:
            # ---- discriminator step: real pairs -> 1, generated pairs -> 0 ----
            discriminator.train()
            for batch in tqdm(range(n_batch)):
                optim_D.zero_grad()
                train_user, pos_item, train_items, max_inter = dataloader.sample()

                neg_user_feature, pos_user_featuer, item_feature = generater(train_user, train_items, max_inter,
                                                                             pos_item)
                real_input = torch.cat([pos_user_featuer.squeeze(), item_feature[pos_item]], 1)
                neg_item_feature = item_feature[recall_item[train_user]]

                real = discriminator(real_input.detach())
                real_loss = Dis_loss(real, torch.ones_like(real))

                _, fake_item = get_rat(neg_user_feature, neg_item_feature, discriminator)

                # per user, take the feature of the candidate the discriminator rated highest
                feak_user_feature = [neg_user_feature[i, fake_item[i], :] for i, user in enumerate(train_user)]
                feak_item_feature = [neg_item_feature[i, fake_item[i], :] for i, user in enumerate(train_user)]

                fake_input = torch.cat([torch.vstack(feak_user_feature), torch.vstack(feak_item_feature)], 1)
                fake = discriminator(fake_input.detach())
                fake_loss = Dis_loss(fake, torch.zeros_like(fake))
                Dloss = (real_loss + fake_loss) / 2
                Dloss.backward()
                optim_D.step()
                d_f_loss += fake_loss.detach().cpu().numpy()
                d_r_loss += real_loss.detach().cpu().numpy()
        else:
            # ---- generator step: make the discriminator label fakes as real ----
            generater.train()
            for batch in tqdm(range(n_batch)):
                train_user, pos_item, train_items, max_inter = dataloader.sample()
                neg_user_feature, pos_user_featuer, item_feature = generater(train_user, train_items, max_inter,
                                                                             pos_item)
                neg_item_feature = item_feature[recall_item[train_user]]
                _, fake_item = get_rat(neg_user_feature, neg_item_feature, discriminator)
                feak_user_feature = [neg_user_feature[i, fake_item[i], :] for i, user in enumerate(train_user)]
                feak_item_feature = [neg_item_feature[i, fake_item[i], :] for i, user in enumerate(train_user)]
                fake_input = torch.cat([torch.vstack(feak_user_feature), torch.vstack(feak_item_feature)], 1)
                optim_G.zero_grad()
                new_fake = discriminator(fake_input)

                Gloss = Dis_loss(new_fake,torch.ones_like(new_fake))
                Gloss.backward()
                optim_G.step()
                g_loss += Gloss.detach().cpu().numpy()
        t2 = time()
        test_user = list(dataloader.test_set.keys())
        val_user = list(dataloader.val_set.keys())
        test_item = dataloader.test_items
        val_item = dataloader.val_items
        # NOTE(review): max_inter is whatever the LAST training batch produced —
        # confirm that reusing it as the evaluation padding length is intended
        ret = test(val_user,val_item,max_inter,True)
        t3 = time()
        loss_loger.append((d_r_loss+d_f_loss)/2+g_loss)
        rec_loger.append(ret['recall'].data)
        pre_loger.append(ret['precision'].data)
        ndcg_loger.append(ret['ndcg'].data)
        hit_loger.append(ret['hit_ratio'].data)

        perf_str = 'Epoch %d [%.1fs + %.1fs]: train==[%.5f=%.5f + %.5f + %.5f], recall=[%.5f, %.5f, %.5f], ' \
                   'precision=[%.5f, %.5f, %.5f], hit=[%.5f, %.5f, %.5f], ndcg=[%.5f, %.5f, %.5f]' % \
                   (epoch, t2 - t1, t3 - t2, g_loss+(d_f_loss+d_r_loss)/2, g_loss, d_f_loss, d_r_loss, ret['recall'][0], ret['recall'][1],
                    ret['recall'][2],
                    ret['precision'][0], ret['precision'][1], ret['precision'][2],
                    ret['hit_ratio'][0], ret['hit_ratio'][1], ret['hit_ratio'][2],
                    ret['ndcg'][0], ret['ndcg'][1], ret['ndcg'][2])

        logger.logging(perf_str)

        if ret['recall'][1] > best_recall:
            best_recall = ret['recall'][1]
            test_ret = test(test_user,test_item,max_inter,is_val=False)
            logger.logging("Test_Recall@%d: %.5f,  precision=[%.5f], ndcg=[%.5f]" % (
            eval(args.Ks)[1], test_ret['recall'][1], test_ret['precision'][1], test_ret['ndcg'][1]))
            stopping_step = 0
        elif stopping_step < args.early_stop:
            stopping_step += 1
            logger.logging('#####Early stopping steps: %d #####' % stopping_step)
        else:
            logger.logging('#####Early stop! #####')
            break  # fix: original logged "Early stop" but kept training
    if test_ret is not None:
        result_logger.logging("Test_Recall@%d: %.5f,  precision=[%.5f], ndcg=[%.5f], epoch=%d " % (
                    eval(args.Ks)[1], test_ret['recall'][1], test_ret['precision'][1], test_ret['ndcg'][1], epoch))


def set_seed(seed):
    """Seed every RNG the pipeline touches (numpy, random, torch CPU/CUDA)
    so runs are reproducible."""
    for seeder in (np.random.seed, random.seed,
                   torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)

def bpr_loss_calculate(users, pos_items, neg_items):
    """BPR pairwise ranking loss with an L2 embedding penalty.

    Args:
        users: user embeddings, one row per training pair.
        pos_items: embeddings of the observed (positive) items.
        neg_items: embeddings of the sampled negative items.

    Returns:
        (mf_loss, emb_loss, reg_loss): the ranking loss, the decay-weighted
        L2 penalty (normalized by args.batch_size), and a constant 0.0.
    """
    pos_scores = (users * pos_items).sum(dim=1)
    neg_scores = (users * neg_items).sum(dim=1)

    # standard BPR objective: maximize log sigmoid of the score margin
    mf_loss = -F.logsigmoid(pos_scores - neg_scores).mean()

    l2_norm = 0.5 * ((users ** 2).sum()
                     + (pos_items ** 2).sum()
                     + (neg_items ** 2).sum())
    emb_loss = args.decay * (l2_norm / args.batch_size)

    return mf_loss, emb_loss, 0.0


def GCN_train():
    """Pre-train LightGCN with BPR loss to produce the recall candidate set.

    Validates every epoch; when early stopping fires, returns the best-epoch
    embeddings and the top recalled item indices per user as
    (saved_user_emb, saved_item_emb, item_index). If all epochs complete
    without early stopping, falls through and implicitly returns None
    (matching the original control flow). Uses module-level globals:
    GCN, dataloader, logger, result_logger, args.
    """
    saved_item_emb, saved_user_emb = None, None
    optimizer = torch.optim.Adam(GCN.parameters(), lr=args.lr)
    GCN.apply(initialize_weights)
    GCN.cuda()
    line_var_recall, line_var_precision, line_var_ndcg = [], [], []
    bpr_loss_loger, reg_loss_logger = [], []
    pre_loger, rec_loger, ndcg_loger, hit_loger = [], [], [], []
    best_recall = 0
    # fix: if recall never improved, the final log line raised NameError on test_ret
    test_ret = None
    n_batch = dataloader.n_train // args.batch_size + 1
    stopping_step = 0
    for epoch in range(args.epoch):
        t1 = time()
        batch_bpr_loss, batch_bpr_reg = 0., 0.
        GCN.train()
        for batch in tqdm(range(n_batch)):
            optimizer.zero_grad()
            train_user, pos_item, neg_item = dataloader.GCN_sample()
            user_emb, item_emb = GCN()
            mf_loss, reg_loss, _ = bpr_loss_calculate(user_emb[train_user], item_emb[pos_item], item_emb[neg_item])
            loss = mf_loss + reg_loss
            loss.backward()
            optimizer.step()
            batch_bpr_loss += mf_loss.detach().cpu().numpy()
            batch_bpr_reg += reg_loss.detach().cpu().numpy()
        t2 = time()
        test_user = list(dataloader.test_set.keys())
        val_user = list(dataloader.val_set.keys())
        ret = GCN_test(val_user, True)
        t3 = time()
        bpr_loss_loger.append(batch_bpr_loss)
        reg_loss_logger.append(batch_bpr_reg)
        rec_loger.append(ret['recall'].data)
        pre_loger.append(ret['precision'].data)
        ndcg_loger.append(ret['ndcg'].data)
        hit_loger.append(ret['hit_ratio'].data)

        line_var_recall.append(ret['recall'][1])
        line_var_precision.append(ret['precision'][1])
        line_var_ndcg.append(ret['ndcg'][1])
        perf_str = 'Epoch %d [%.1fs + %.1fs]: train==[%.5f=%.5f + %.5f + %.5f], recall=[%.5f, %.5f, %.5f], ' \
                   'precision=[%.5f, %.5f, %.5f], hit=[%.5f, %.5f, %.5f], ndcg=[%.5f, %.5f, %.5f]' % \
                   (epoch, t2 - t1, t3 - t2, batch_bpr_loss+batch_bpr_reg, batch_bpr_loss, batch_bpr_reg, 0., ret['recall'][0], ret['recall'][1],
                    ret['recall'][2],
                    ret['precision'][0], ret['precision'][1], ret['precision'][2],
                    ret['hit_ratio'][0], ret['hit_ratio'][1], ret['hit_ratio'][2],
                    ret['ndcg'][0], ret['ndcg'][1], ret['ndcg'][2])

        logger.logging(perf_str)

        if ret['recall'][1] > best_recall:
            best_recall = ret['recall'][1]
            test_ret = GCN_test(test_user, is_val=False)
            logger.logging("Test_Recall@%d: %.5f,  precision=[%.5f], ndcg=[%.5f]" % (
            eval(args.Ks)[1], test_ret['recall'][1], test_ret['precision'][1], test_ret['ndcg'][1]))
            stopping_step = 0
            # keep the best-epoch embedding tables for the recall step below
            saved_item_emb = GCN.item_emb
            saved_user_emb = GCN.user_emb
        elif stopping_step < args.early_stop:
            stopping_step += 1
            logger.logging('#####Early stopping steps: %d #####' % stopping_step)
        else:
            logger.logging('##### Recall Finish! #####')
            rat = torch.matmul(saved_user_emb.weight, torch.transpose(saved_item_emb.weight, 0, 1))
            # fix: the original used len(item_emb), a variable left over from
            # the last training batch; the item count belongs to the saved table
            n_items = saved_item_emb.weight.shape[0]
            _, item_index = rat.topk(round(args.Recall_rate * n_items), dim=1)
            return saved_user_emb, saved_item_emb, item_index.detach()
    if test_ret is not None:
        result_logger.logging("Test_Recall@%d: %.5f,  precision=[%.5f], ndcg=[%.5f], epoch=%d " % (
                    eval(args.Ks)[1], test_ret['recall'][1], test_ret['precision'][1], test_ret['ndcg'][1], epoch))

if __name__ == '__main__':
    batch_size = args.batch_size
    # fix: // 5 yields 0 workers on machines with fewer than 5 cores, and
    # multiprocessing.Pool(0) raises ValueError — always keep at least one
    cores = max(1, multiprocessing.cpu_count() // 5)

    os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu_id)
    set_seed(2023)

    # one log file per run, named by timestamp and dataset
    task_name = "%s_%s" % (datetime.now().strftime('%Y-%m-%d-%H-%M-%S'), args.dataset,)
    logger = Logger(filename=task_name)
    logger.logging(str(args))
    result_logger = Logger(filename="first_test.txt")

    test_user = list(dataloader.test_set.keys())
    val_user = list(dataloader.val_set.keys())

    GCN = Light_GCN(args.emb_size)
    # NOTE: to regenerate the cached embeddings below, run GCN_train() and
    # np.save the returned user/item embeddings and recall indices:
    # user_emb,item_emb,recall_item = GCN_train()
    # np.save('emb/tik_uer_emb',user_emb.weight.detach().cpu().numpy())
    # np.save('emb/tik_item_emb', item_emb.weight.detach().cpu().numpy())
    # np.save('emb/tik_recall_item', recall_item.cpu().numpy())
    user_emb = torch.from_numpy(np.load('uer_emb.npy')).cuda()
    item_emb = torch.from_numpy(np.load('item_emb.npy')).cuda()
    # precomputed per-user recalled item ids (kept on CPU; indexed in train/test)
    recall_item = torch.from_numpy(np.load('recall_item.npy'))

    discriminator = Discriminator(args.emb_size * 4)
    generater = Generater(args.emb_size, user_emb, item_emb, recall_item)
    optim_D = optim.Adam(discriminator.parameters(), lr=args.D_lr)
    optim_G = optim.Adam(generater.parameters(), lr=args.G_lr)
    generater.apply(initialize_weights)
    discriminator.apply(initialize_weights)
    Dis_loss = nn.BCELoss(reduction="mean")
    generater.cuda()
    discriminator.cuda()

    train()